
[Model] Add support for Jina Embeddings V4 #20802


Open

sigridjineth wants to merge 24 commits into main from jina-support

Conversation


@sigridjineth commented on Jul 11, 2025

Purpose

This PR adds support for the Jina Embeddings V4 model
(jinaai/jina-embeddings-v4-vllm-retrieval) in vLLM, enabling multimodal embeddings for text and image inputs.

FIX #20463

Test Plan

  • Run the comprehensive test suite for Jina V4:
    pytest tests/models/multimodal/test_jina_embeddings_v4.py -v

Documentation Update

  The following documentation has been added:

  1. Model Registry (vllm/model_executor/models/registry.py): Added a JinaVLForEmbedding entry
  2. Supported Models (docs/source/models/supported_models.rst): Added Jina Embeddings V4 to the multimodal embeddings section
  3. Example Script (examples/offline_inference/jina_embeddings_v4.py): Offline inference example for text and image embeddings (see the sketch below)
  4. Implementation Guide (JINA_EMBEDDINGS_V4_IMPLEMENTATION.md): Detailed technical documentation
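For illustration, a minimal sketch of what such an offline inference script could look like. This is hedged: the exact prompt template and engine flags expected by jina-embeddings-v4-vllm-retrieval are assumptions, not taken from the PR's example script.

    # Hypothetical sketch; check the actual example script for the real prompt format.
    from PIL import Image
    from vllm import LLM

    llm = LLM(
        model="jinaai/jina-embeddings-v4-vllm-retrieval",
        task="embed",  # run the model as an embedding model
    )

    # Text embedding
    text_out = llm.embed(["Query: What is multimodal retrieval?"])
    print(len(text_out[0].outputs.embedding))

    # Image embedding (multi_modal_data follows vLLM's multimodal prompt format)
    image_out = llm.embed({
        "prompt": "<|vision_start|><|image_pad|><|vision_end|>",  # placeholder tokens; an assumption
        "multi_modal_data": {"image": Image.open("example.jpg")},
    })
    print(len(image_out[0].outputs.embedding))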


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist bot left a comment

Summary of Changes

Hello @sigridjineth, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands vLLM's capabilities by integrating native support for the Jina Embeddings V4 multimodal model. My work ensures that users can now efficiently generate embeddings for both text and image inputs, benefiting from highly optimized processing, robust error handling, and comprehensive testing to guarantee accuracy and stability.

Highlights

  • Jina Embeddings V4 Integration: I've added full, production-ready support for the Jina Embeddings V4 model (jinaai/jina-embeddings-v4-vllm-retrieval), enabling both text and image multimodal embeddings within vLLM.
  • Optimized Multimodal Pooling: The implementation includes thread-safe, token-type-aware pooling, leveraging optimized Triton CUDA kernels for efficient vision token extraction, with a robust PyTorch fallback.
  • Robustness and Observability: I've incorporated comprehensive error handling, including graceful fallback mechanisms and OOM recovery during pooling, alongside full observability integration for performance metrics.
  • Comprehensive Testing & Examples: New unit tests, a validation benchmark against HuggingFace, and an offline inference example script have been added to ensure accuracy and demonstrate usage.
  • Documentation Updates: The model has been registered in the model registry, added to the supported models documentation, and a detailed implementation guide has been provided.

@mergify bot added the documentation, multi-modality (#4194), new-model, and performance labels on Jul 11, 2025
Contributor

@gemini-code-assist bot left a comment

Code Review

This PR adds production-ready support for the Jina Embeddings V4 model. I've identified a bug in the tests, a performance issue in the core implementation, and some areas for code improvement in the example and validation scripts.

Member

@DarkLight1337 left a comment

Thanks for contributing! Can you add this model to the test registry and supported models page?



# Triton kernel for optimized vision token extraction
if HAS_TRITON:
Member

@DarkLight1337 commented on Jul 11, 2025

How much of a performance increase does the Triton kernel provide to justify this additional complexity? cc @Isotr0py @imkero

Author

I will provide Triton performance benchmarks after finishing up some tasks in the PR.

Collaborator

If this Triton kernel is only used in the pooler, I think the performance improvement will be very small. But it would be best to have performance benchmarks first.

Member

Can you perform benchmarking on this?

Collaborator

@Isotr0py commented on Jul 17, 2025

@sigridjineth I ran some benchmarks comparing the Triton kernel against the native torch implementation on an RTX 3090, and found that the Triton kernel can be much slower when the image seq_len is quite long, which is a normal image input case for Qwen2-VL-like models:

Benchmark results (text sequence length = 2048; times in seconds):

| Image seq len | Images | Triton vision pooling | Native vision pooling |
|--------------:|-------:|----------------------:|----------------------:|
| 512           | 1      | 0.0877                | 0.0567                |
| 1024          | 1      | 0.1028                | 0.0344                |
| 8192          | 1      | 0.3178                | 0.0750                |
| 16384         | 1      | 0.5706                | 0.1178                |
| 512           | 2      | 0.0901                | 0.0320                |
| 1024          | 2      | 0.1074                | 0.0352                |
| 8192          | 2      | 0.3502                | 0.0757                |
| 16384         | 2      | 0.6468                | 0.1203                |
| 512           | 4      | 0.0951                | 0.0326                |
| 1024          | 4      | 0.1170                | 0.0354                |
| 8192          | 4      | 0.4278                | 0.0743                |
| 16384         | 4      | 0.8104                | 0.1189                |

Any idea about this? The benchmark script can be found here: https://gist.github.com/Isotr0py/eef7470ff176a28ac40340b883cf1abe
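For reference, a GPU timing harness for this kind of comparison needs warmup iterations and explicit synchronization around the timed region. A minimal generic sketch (this is not the linked gist, just an illustration):

    import time
    import torch

    def time_gpu(fn, *args, warmup: int = 3, iters: int = 10) -> float:
        # Warm up so kernel compilation/autotuning is excluded from the measurement.
        for _ in range(warmup):
            fn(*args)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters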




Comment on lines 48 to 59
@triton.jit
def extract_vision_tokens_kernel(
    hidden_states_ptr,
    token_ids_ptr,
    output_ptr,
    seq_start,
    seq_len,
    hidden_size,
    vision_start_id: tl.constexpr,
    vision_end_id: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
Collaborator

I don't like putting the Triton kernel in the model implementation; we should move it to pooler.py or somewhere else if the performance improvement is significant.

Author

done caea1fe

@sigridjineth force-pushed the jina-support branch 3 times, most recently from 4eb5e88 to caea1fe on July 11, 2025 at 14:39
Author

@sigridjineth commented

@Isotr0py @DarkLight1337 please review and let me know if you think more changes are needed.

Member

@DarkLight1337 commented

Sorry for the delay, can you merge from main and fix pre-commit?

@DarkLight1337 added this to the v0.10.0 milestone on Jul 16, 2025
Address DarkLight1337's review feedback:
- Set logits_processing_needs_token_ids=True for V1 compatibility in both
  "embed" and "encode" tasks
- Support "encode" task by returning PoolingParams() instead of None
- Update log message from "thread-safe pooling" to "vision-aware pooling"
  to better reflect the actual functionality
- Remove unused seq_ids variable from _extract_token_ids_safe method

These changes ensure proper V1 compatibility and cleaner code structure.

Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>

for i in range(seq_len):
    token_id = tl.load(token_ids_ptr + seq_start + i)
    if token_id >= vision_start_id and token_id <= vision_end_id:
Member

Suggested change
-    if token_id >= vision_start_id and token_id <= vision_end_id:
+    if token_id in (vision_start_id, vision_end_id):

Member

Actually now that I think of it, this isn't quite correct? The start index should be the first item that equals vision_start_id, and then all subsequent tokens (regardless of ID) are included until vision_end_id is found

Member

According to your code, you only selected the tokens corresponding to vision_start_id and vision_end_id, but not the tokens in between them.

Member

Can you fix this?

Author

Hi @DarkLight1337, thanks for highlighting this. I looked at it again and your assessment is correct: there's a positional indexing mistake in the implementation.

Author

@sigridjineth commented on Jul 19, 2025

The problem, I think, is that the kernel incorrectly selects only the tokens that exactly match vision_start_id and vision_end_id.

It fails to select the intermediate tokens between these markers because it uses direct ID matching instead of positional masking. I have pushed a commit that introduces a dedicated VisionPooler class, which finds the positions of vision_start_id and vision_end_id via torch.where; the goal is to pool over the entire positional range.

Would like to get your feedback on this: 5114a3c
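A minimal sketch of the positional-masking idea, assuming a single image per sequence; names are illustrative and this is not the actual VisionPooler code:

    import torch

    def pool_vision_range(
        hidden_states: torch.Tensor,  # [seq_len, hidden_size]
        token_ids: torch.Tensor,      # [seq_len]
        vision_start_id: int,
        vision_end_id: int,
    ) -> torch.Tensor:
        # Locate the marker tokens by position instead of matching IDs in between.
        start_pos = torch.where(token_ids == vision_start_id)[0]
        end_pos = torch.where(token_ids == vision_end_id)[0]
        if start_pos.numel() == 0 or end_pos.numel() == 0:
            # No image in this sequence: fall back to pooling the whole sequence.
            return hidden_states.mean(dim=0)
        # Assume one image per sequence; pool everything between the markers.
        start = int(start_pos[0]) + 1
        end = int(end_pos[0])
        return hidden_states[start:end].mean(dim=0)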

Implement efficiency improvements suggested by DarkLight1337:
- Consolidate get_pooling_params method for "embed" and "encode" tasks
- Pre-compute vision token IDs tensor in constructor
- Replace range checks with torch.isin for more efficient vision token detection
  at lines 209-210 and 261-262

This reduces redundant code and improves performance when checking for
vision tokens by using optimized tensor operations.

Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
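The torch.isin pattern referred to in the commit could look roughly like this; the token ID values and names are illustrative assumptions:

    import torch

    # Pre-computed once (e.g. in the pooler constructor) instead of per call.
    VISION_TOKEN_IDS = torch.tensor([151652, 151653, 151655])  # start, end, image_pad (illustrative)

    def contains_vision_tokens(token_ids: torch.Tensor) -> bool:
        # torch.isin performs the membership test as a single vectorized op,
        # replacing per-element range checks.
        return bool(torch.isin(token_ids, VISION_TOKEN_IDS.to(token_ids.device)).any())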

def extract_embeddings(output):
    """Extract embeddings based on token type."""
    if VISION_START_TOKEN_ID in output.prompt_token_ids:
Member

We should test the whole embedding tensor against HF to avoid these kinds of mistakes
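A sketch of the kind of check meant here, comparing the full embedding tensors rather than just which pooling branch was taken; thresholds and helper names are illustrative, not the PR's actual test:

    import torch

    def assert_embeddings_match(vllm_embedding, hf_embedding, atol: float = 1e-2):
        vllm_t = torch.tensor(vllm_embedding, dtype=torch.float32)
        hf_t = torch.tensor(hf_embedding, dtype=torch.float32)
        # Cosine similarity catches direction mismatches; assert_close catches scale drift.
        cos = torch.nn.functional.cosine_similarity(vllm_t, hf_t, dim=-1)
        assert cos.item() > 0.99, f"cosine similarity too low: {cos.item():.4f}"
        torch.testing.assert_close(vllm_t, hf_t, atol=atol, rtol=1e-3)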

Comment on lines 69 to 78
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return pooling params for embedding task."""
    if task == "embed" or task == "encode":
        return PoolingParams(logits_processing_needs_token_ids=True)

    # The equalities are split up to keep mypy happy
    if task == "classify" or task == "score":
        return None

    assert_never(task)
Member

Can you merge from main and then apply this update? (Need to update the imports accordingly as well)

Suggested change
-def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
-    """Return pooling params for embedding task."""
-    if task == "embed" or task == "encode":
-        return PoolingParams(logits_processing_needs_token_ids=True)
-    # The equalities are split up to keep mypy happy
-    if task == "classify" or task == "score":
-        return None
-    assert_never(task)
+def get_pooling_updates(
+    self,
+    task: PoolingTask,
+) -> Optional[PoolingParamsUpdate]:
+    # The equalities are split up to keep mypy happy
+    if task == "encode" or task == "embed":
+        return PoolingParamsUpdate(requires_token_ids=True)
+    if task == "classify" or task == "score":
+        return None
+    assert_never(task)

Comment on lines 132 to 175
def _extract_token_ids_safe(
        self, pooling_metadata: PoolingMetadata) -> list[array]:
    """Safely extract token IDs from pooling metadata."""
    token_ids_list: list[array] = []
    try:
        if isinstance(pooling_metadata, V1PoolingMetadata):
            # For V1, we get token IDs directly
            for i, num in enumerate(pooling_metadata.prompt_lens):
                token_ids = pooling_metadata.prompt_token_ids[
                    i, :num].tolist()
                token_ids_list.append(array('l', token_ids))

            return token_ids_list

        # For V0, we extract from seq_groups and seq_data
        for seq_group, _ in pooling_metadata.seq_groups:
            for seq_id in seq_group:
                if seq_id not in pooling_metadata.seq_data:
                    logger.warning("Sequence %s not found in seq_data",
                                   seq_id)
                    continue

                seq_data = pooling_metadata.seq_data[seq_id]

                # Get prompt token IDs safely
                if hasattr(seq_data, 'prompt_token_ids_array'):
                    token_ids = seq_data.prompt_token_ids_array
                elif hasattr(seq_data, '_prompt_token_ids'):
                    token_ids = seq_data._prompt_token_ids
                else:
                    logger.warning("No token IDs found for sequence %s",
                                   seq_id)
                    continue

                token_ids_list.append(token_ids)

        return token_ids_list

    except Exception as e:
        logger.error(
            "Error extracting token IDs: %s. "
            "Extracted %d sequences before failure", e,
            len(token_ids_list))
        raise
Member

The latest main now has get_prompt_token_ids in pooler.py which can replace this functionality (but note that it outputs torch.Tensor instead of array.array)
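Sketched usage of that helper, under the assumption that it takes the pooling metadata and returns one tensor of prompt token IDs per sequence; the import path and exact signature should be checked against pooler.py on main:

    import torch
    from vllm.model_executor.layers.pooler import get_prompt_token_ids

    def _extract_token_ids(self, pooling_metadata) -> list[torch.Tensor]:
        # Replaces the hand-rolled V0/V1 extraction above; note these are
        # torch.Tensor objects rather than array.array.
        return list(get_prompt_token_ids(pooling_metadata))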

Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
@@ -32,6 +32,7 @@ class PoolingType(IntEnum):
     CLS = 2
     STEP = 3
     MEAN = 4
+    VISION = 5
Author

I have added this new VISION pooling type to the PoolingType enum.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import gc
Member

Can you test the correctness of the model against the HF implementation?

Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
@@ -3258,7 +3258,8 @@ def get_limit_per_prompt(self, modality: str) -> int:
 class PoolerConfig:
     """Controls the behavior of output pooling in pooling models."""

-    pooling_type: Optional[str] = None
+    pooling_type: Optional[Literal["last", "all", "cls", "step", "mean",
Member

Actually the pooling type here is supposed to be upper case

    def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
        return cls(model_config)

    def __init__(self, config: ModelConfig):
Member

Can we pass in the token IDs and hidden size explicitly? In case other models store those attributes in different locations
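Something along these lines, as a sketch of the suggested constructor shape; parameter names and the config attributes in from_config are illustrative assumptions:

    class VisionPooler:

        def __init__(self, vision_start_id: int, vision_end_id: int,
                     hidden_size: int) -> None:
            # Take the token IDs and hidden size directly instead of digging them
            # out of a ModelConfig, so models that store these attributes in
            # different locations can still construct the pooler.
            self.vision_start_id = vision_start_id
            self.vision_end_id = vision_end_id
            self.hidden_size = hidden_size

        @classmethod
        def from_config(cls, model_config) -> "VisionPooler":
            # Convenience path for models whose HF config exposes these attributes.
            hf_config = model_config.hf_config
            return cls(hf_config.vision_start_token_id,
                       hf_config.vision_end_token_id,
                       hf_config.hidden_size)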

        super().__init__(vllm_config=vllm_config,
                         prefix=maybe_prefix(prefix, "qwen2_vl"))

        self.pooler = JinaVLPooler(vllm_config)
Member

Why not directly use VisionPooler here?


mergify bot commented Jul 21, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sigridjineth.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on Jul 21, 2025
Labels
ci/build, documentation, multi-modality (#4194), needs-rebase, new-model, performance
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[New Model]: jinaai/jina-embeddings-v4
3 participants